import numpy as np 
from helpers_DS import check_if_color_are_covered
from scipy.spatial.distance import euclidean


def _points_in_cluster(x_sol, clust):
    """Return the indices of the points whose entry in column ``clust`` equals 1."""
    # np.flatnonzero always yields a 1-D array.  np.squeeze(np.argwhere(...))
    # collapses a single-member cluster to a 0-d array, which cannot be
    # iterated.
    return np.flatnonzero(x_sol[:, clust] == 1)


def _first_unpicked_point(x_sol, color_flag, clust, color, picked):
    """Return the first point of ``color`` in ``clust`` that is not in ``picked``, else None."""
    for point in _points_in_cluster(x_sol, clust):
        if color_flag[point] == color and point not in picked:
            return point
    return None


def GF_to_GFDS(df,  x, color_flag, clust_indices_GF, num_clusters, num_colors, centerUpperBound, centerLowerBound):
    """Pick color-constrained centers for an existing clustering and split clusters accordingly.

    Steps performed:
      1. Drop clusters (columns of the assignment matrix) that have no points.
      2. Part 1: pick one center per remaining cluster, preferring colors that
         are still below their lower bound.
      3. Part 2: for every color still below its lower bound, pick one extra
         center of that color from the first cluster that still has an
         unpicked point of that color.
      4. Parts 3/4: every cluster with several picked centers is divided into
         that many sub-clusters, splitting the points of each color as evenly
         as possible between them.
      5. Compute the largest Euclidean distance between a point and the
         center of the (sub-)cluster it ends up in.

    Args:
        df: Point coordinates, indexed as ``df[i, :]`` (presumably a 2-D
            NumPy array with one row per point -- confirm against caller).
        x: Flat assignment vector; it is reshaped to
            ``(num_points, num_clusters)`` and an entry equal to 1 marks
            membership of a point in a cluster.
        color_flag: Sequence giving the color index of each point; values are
            compared against ``range(num_colors)``.
        clust_indices_GF: List with one entry per input cluster.  It is
            MODIFIED IN PLACE: entries of empty clusters are deleted.  It is
            not otherwise read here.
        num_clusters: Number of columns of the input assignment.
        num_colors: Number of colors.
        centerUpperBound: Per-color upper bound on the number of centers.
        centerLowerBound: Per-color lower bound on the number of centers.

    Returns:
        A tuple ``(centers, assignment, cost, num_centers, emptycluster_flag)``:
        the list of point indices chosen as centers, the new assignment
        matrix flattened row-major into a list, the maximum point-to-center
        Euclidean distance, ``len(centers)``, and whether any input cluster
        was empty.

    Raises:
        AssertionError: If some non-empty cluster has no point of some color.
        ValueError: If no center could be picked for some cluster.
    """
    x_sol = np.reshape(x, (-1, num_clusters))
    cluster_sizes = x_sol.sum(axis=0)

    # find and delete the empty clusters
    empty_clusters = [c for c in range(num_clusters) if cluster_sizes[c] == 0]
    emptycluster_flag = bool(empty_clusters)
    if empty_clusters:
        x_sol = np.delete(x_sol, empty_clusters, axis=1)
        # Delete from the back so the remaining positions stay valid.
        for empty in reversed(empty_clusters):
            del clust_indices_GF[empty]

    # find the number of active centers
    num_clusters_active = num_clusters - len(empty_clusters)

    # find color sizes in the clusters
    color_cluster_sizes = np.zeros((num_colors, num_clusters_active))
    for color in range(num_colors):
        color_indices = [i for i, c in enumerate(color_flag) if c == color]
        color_cluster_sizes[color, :] = x_sol[color_indices, :].sum(axis=0)

    # make sure that there is no cluster that lacks a point of some color
    # (the bi-criteria algorithm should be sufficiently good)
    assert (color_cluster_sizes > 0).all(), \
        "every non-empty cluster must contain at least one point of every color"

    # sh counts the number of centers of each color
    sh = [0] * num_colors

    # list of new centers, per cluster
    Q = {clust: [] for clust in range(num_clusters_active)}

    # number of points of each color in each cluster that can still be picked
    color_cluster_sizes_picked = np.copy(color_cluster_sizes)

    def _take(clust, color):
        # Pick one not-yet-picked point of `color` in `clust` as a center.
        point = _first_unpicked_point(x_sol, color_flag, clust, color, Q[clust])
        if point is None:
            return
        sh[color] += 1
        Q[clust].append(point)
        color_cluster_sizes_picked[color, clust] -= 1

    # Part 1: one center per cluster
    for clust in range(num_clusters_active):
        # The helper's result is treated as "some color is not covered yet".
        if check_if_color_are_covered(sh, centerLowerBound):
            # first color that has not yet reached its lower bound
            color = next((c for c in range(num_colors)
                          if sh[c] < centerLowerBound[c]), None)
        else:
            # first color for which one more center can be added
            # NOTE(review): the strict '<' means a color never actually
            # reaches centerUpperBound here; '<=' may have been intended.
            # Kept as in the original -- confirm.
            color = next((c for c in range(num_colors)
                          if sh[c] + 1 < centerUpperBound[c]), None)
        if color is not None:
            _take(clust, color)

    # Part 2: if there is still a color which has not been covered
    if check_if_color_are_covered(sh, centerLowerBound):
        for color in range(num_colors):
            # if this color has not yet reached its lower bound
            if sh[color] < centerLowerBound[color]:
                # find a cluster that still has an unpicked point of this color
                for clust in range(num_clusters_active):
                    if color_cluster_sizes_picked[color, clust] > 0:
                        _take(clust, color)
                        # Stop scanning clusters: at most ONE extra center is
                        # added per deficient color.
                        # NOTE(review): a color that is short by more than one
                        # center stays below its lower bound -- confirm that
                        # this is intended.
                        break

    # Part 3: get all of the new centers
    double_GF_clust_indices = [point
                               for clust in range(num_clusters_active)
                               for point in Q[clust]]

    num_clusters_new = len(double_GF_clust_indices)
    num_points = np.shape(x_sol)[0]

    x_new = np.zeros((num_points, num_clusters_new))

    # Part 4: divide every cluster between its picked centers
    new_clust_index = 0
    for clust in range(num_clusters_active):
        size_Q = len(Q[clust])

        if size_Q == 0:
            # Without this check the split below divides by zero and the
            # cluster's points stay unassigned, which fails later with an
            # opaque indexing error in the cost computation.
            raise ValueError(
                "no center could be picked for cluster %d; "
                "check centerUpperBound/centerLowerBound" % clust)

        if size_Q == 1:
            x_new[:, new_clust_index] = x_sol[:, clust]
            new_clust_index += 1
            continue

        # color_sizes is H x |Q|: number of points of each color that each
        # center of this cluster receives
        color_sizes = np.zeros((num_colors, size_Q))
        startIndex = 0
        for h in range(num_colors):
            start = startIndex
            # color division size
            T_h = color_cluster_sizes[h, clust] / size_Q
            # extra remainders
            b_h = color_cluster_sizes[h, clust] - size_Q * np.floor(T_h)

            for _ in range(size_Q):
                if b_h > 0:
                    color_sizes[h, start] = np.ceil(T_h)
                    b_h = b_h - 1
                    # The next color starts handing out its remainder right
                    # after the last center that received an extra point.
                    # The original used '+=', so startIndex could grow past
                    # size_Q and index out of bounds.
                    startIndex = (start + 1) % size_Q
                else:
                    color_sizes[h, start] = np.floor(T_h)
                start = (start + 1) % size_Q

        # Hand the points of each color to the centers in consecutive runs,
        # in increasing point-index order.
        members = _points_in_cluster(x_sol, clust)
        for h in range(num_colors):
            of_color = np.array([i for i in members if color_flag[i] == h], dtype=int)
            begin = 0
            for q in range(size_Q):
                end = begin + int(color_sizes[h, q])
                x_new[of_color[begin:end], new_clust_index + q] = 1
                begin = end

        new_clust_index = new_clust_index + size_Q

    x_assignment_doubly_GF = x_new.ravel().tolist()

    # find the clustering cost: largest distance of a point to its center
    clustering_cost = 0
    for i in range(num_points):
        # Exactly one entry per row is expected to be 1, so squeeze gives a
        # 0-d index that can be used on the list of centers.
        center = np.squeeze(np.argwhere(x_new[i, :] == 1))

        center_vec = df[double_GF_clust_indices[center], :]
        i_vec = df[i, :]
        dist = euclidean(i_vec, center_vec)

        if dist > clustering_cost:
            clustering_cost = dist

    return double_GF_clust_indices, x_assignment_doubly_GF, clustering_cost, len(double_GF_clust_indices), emptycluster_flag



